BiT: Robustly Binarized Multi-Distilled Transformer
FIGURE 5.16
Overview of BiT. A transformer block contains a multi-head self-attention module and a feed-forward network. All weights in the Embedding/Fully-Connected layers are binarized to {−1, 1}, while activations are binarized to {0, 1} for ReLU/Softmax outputs and to {−1, 1} for the other layers.
distilling higher-precision models into lower-precision students. They are introduced in detail as follows.
5.10.1 Two-Set Binarization Scheme
In contrast to CNNs on images, where activations exhibit comparable distributions, the activations in a transformer block serve different functions and thus vary in their output distributions. In particular, these activations can be divided into two categories: the activations after a Softmax/ReLU layer, which contain only non-negative values, and the remaining activations, which take both positive and negative values (e.g., after a matrix multiplication). If we denote by $X_R$ the vector of activation values, the two cases are $X_R^i \in \mathbb{R}^+$ and $X_R^i \in \mathbb{R}$, respectively.
For the former set, mapping to the binary levels {−1, 1} would result in a severe distribution mismatch. Therefore, the authors instead mapped non-negative activation layers to $\hat{X}_B \in \{0, 1\}^n$ and binarized activation layers with $X_R \in \mathbb{R}^n$ to $\hat{X}_B \in \{-1, 1\}^n$, as shown in Fig. 5.16. BiBERT [195] also suggests binarizing attention to {0, 1}, but with a Boolean function replacing the Softmax; in contrast, the authors empirically found that simply binarizing the attention after the Softmax to {0, 1} works better, and that binarizing the ReLU output to {0, 1} instead of {−1, 1} brings further improvements.
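As an illustration of the two-set scheme, the following minimal PyTorch sketch binarizes Softmax/ReLU-style outputs to {0, 1} and the remaining activations to {−1, 1}. The function names and the clamp-and-round rule for the {0, 1} set are illustrative assumptions rather than the paper's exact formulation, which also applies scaling (discussed next) and handles gradients during training.

```python
import torch

def binarize_signed(x: torch.Tensor) -> torch.Tensor:
    """Map activations with both signs to {-1, 1} via the sign of each entry."""
    return torch.where(x >= 0, torch.ones_like(x), -torch.ones_like(x))

def binarize_nonneg(x: torch.Tensor) -> torch.Tensor:
    """Map non-negative activations (e.g., Softmax/ReLU outputs) to {0, 1}
    by clamping to [0, 1] and rounding to the nearest integer (illustrative rule)."""
    return torch.clamp(x, 0, 1).round()

# Example: attention probabilities (after Softmax) use the {0, 1} set,
# while the output of a matrix multiplication uses the {-1, 1} set.
attn_probs = torch.softmax(torch.randn(2, 4, 4), dim=-1)
hidden = torch.randn(2, 4, 8)

attn_bin = binarize_nonneg(attn_probs)   # values in {0, 1}
hidden_bin = binarize_signed(hidden)     # values in {-1, 1}
```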
Additionally, they applied a layer-wise scaling factor to the binarized activations to reduce the binarization error, i.e., $X_B = \alpha \hat{X}_B$. The optimal values of $\alpha$ are different for the $\hat{X}_B \in \{0, 1\}^n$ and $\hat{X}_B \in \{-1, 1\}^n$ cases and can be calculated by minimizing the $\ell_2$ error:
$$\mathcal{J}(\alpha) = \| X_R - \alpha \hat{X}_B \|^2, \qquad \alpha^* = \arg\min_{\alpha \in \mathbb{R}^+} \mathcal{J}(\alpha) \tag{5.35}$$
Following XNOR-Net [199], by expanding Eq. 5.35 we have
$$\mathcal{J}(\alpha) = \alpha^2 \hat{X}_B^{\top} \hat{X}_B - 2\alpha X_R^{\top} \hat{X}_B + X_R^{\top} X_R \tag{5.36}$$
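As a sanity check on Eq. 5.36, the sketch below computes the closed-form minimizer obtained by setting $d\mathcal{J}/d\alpha = 0$, namely $\alpha^* = (X_R^{\top}\hat{X}_B)/(\hat{X}_B^{\top}\hat{X}_B)$. The 0/1 thresholding used to produce $\hat{X}_B$ in the second case is a hypothetical choice for illustration, not the paper's exact binarizer.

```python
import torch

def optimal_alpha(x_r: torch.Tensor, x_b: torch.Tensor) -> torch.Tensor:
    """Closed-form minimizer of J(alpha) = ||X_R - alpha*X_B||^2 (Eq. 5.36):
    setting dJ/dalpha = 0 gives alpha* = (X_R^T X_B) / (X_B^T X_B)."""
    return (x_r * x_b).sum() / x_b.pow(2).sum().clamp(min=1e-12)

x = torch.randn(1024)

# {-1, 1} case: with X_B = sign(X_R), X_B^T X_B = n and X_R^T X_B = ||X_R||_1,
# so alpha* reduces to the mean absolute value, as in XNOR-Net.
x_b_signed = torch.where(x >= 0, torch.ones_like(x), -torch.ones_like(x))
alpha_signed = optimal_alpha(x, x_b_signed)
assert torch.allclose(alpha_signed, x.abs().mean())

# {0, 1} case (illustrative thresholding): alpha* averages X_R only over the
# entries mapped to 1, so the two sets generally need different scales.
x_pos = torch.relu(x)
x_b_nonneg = (x_pos > 0.5).float()
alpha_nonneg = optimal_alpha(x_pos, x_b_nonneg)
```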